import asyncio
from typing import Dict

import pytest
import pytest_asyncio
from pydantic.v1 import BaseModel, StrictInt, conint

from hivemind import DHT
from hivemind.dht.node import DHTNode
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.dht.validation import DHTRecord, RecordValidatorBase
from hivemind.utils.timed_storage import get_dht_time


class SampleSchema(BaseModel):
    """Record schema shared by the DHT schema-validation tests in this module.

    Fields (each field name is a DHT key):
      - ``experiment_name``: a plain ``bytes`` value, stored without a subkey;
      - ``n_batches``: a dictionary value whose subkeys are ``bytes`` and whose values are strict, non-negative ints;
      - ``signed_data``: a dictionary value whose subkeys must carry a public key (``BytesWithPublicKey``).
    """

    experiment_name: bytes
    # strict=True: ints only, no coercion from floats/strings; ge=0: negative counts are refused
    n_batches: Dict[bytes, conint(strict=True, ge=0)]
    signed_data: Dict[BytesWithPublicKey, bytes]


@pytest_asyncio.fixture
async def dht_nodes_with_schema():
    """Yield a connected pair of ``DHTNode``s ``(alice, bob)`` validating records against ``SampleSchema``.

    Both nodes are created with the same ``SchemaValidator(SampleSchema)``; bob bootstraps from
    alice's visible multiaddrs. Both nodes are shut down once the test finishes.
    """
    validator = SchemaValidator(SampleSchema)

    alice = await DHTNode.create(record_validator=validator)
    try:
        bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
    except BaseException:
        # The teardown after ``yield`` is never reached if bob fails to start, so release alice here
        await alice.shutdown()
        raise

    yield alice, bob

    await asyncio.gather(alice.shutdown(), bob.shutdown())


@pytest.mark.forked
@pytest.mark.asyncio
async def test_expecting_regular_value(dht_nodes_with_schema):
    """``experiment_name`` accepts only a ``bytes`` value stored without a subkey."""
    alice, bob = dht_nodes_with_schema

    def expires_soon():
        # Re-read the clock on every store, so each record has its own expiration time
        return get_dht_time() + 10

    assert await bob.store("experiment_name", b"foo_bar", expires_soon())

    # (value, extra ``store`` kwargs) that the schema must refuse, in the order they are tried.
    # Lists are refused despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
    rejected_records = [
        (666, {}),
        (b"foo_bar", {"subkey": b"subkey"}),
        ([], {}),
        ([1, 2, 3], {}),
    ]
    for value, store_kwargs in rejected_records:
        assert not await bob.store("experiment_name", value, expires_soon(), **store_kwargs), (value, store_kwargs)

    for node in (alice, bob):
        found = await node.get("experiment_name", latest=True)
        assert found.value == b"foo_bar"


@pytest.mark.forked
@pytest.mark.asyncio
async def test_expecting_dictionary(dht_nodes_with_schema):
    """``n_batches`` accepts only ``bytes`` subkeys mapped to non-negative strict ints, one subkey per store."""
    alice, bob = dht_nodes_with_schema

    def expires_soon():
        # Re-read the clock on every store, so each record has its own expiration time
        return get_dht_time() + 10

    for subkey, value in [(b"uid1", 777), (b"uid2", 778)]:
        assert await bob.store("n_batches", value, expires_soon(), subkey=subkey)

    # (value, extra ``store`` kwargs) that the schema must refuse, in the order they are tried
    rejected_records = [
        (-666, {"subkey": b"uid3"}),  # negative count
        (666, {}),  # no subkey
        (b"not_integer", {"subkey": b"uid1"}),  # value is not an int
        (666, {"subkey": 666}),  # subkey is not bytes
        # A plain dictionary must not bypass the DictionaryDHTValue convention
        ({b"uid3": 779}, {}),
        # Refused despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
        (779.5, {"subkey": b"uid3"}),
        (779.0, {"subkey": b"uid3"}),
        ([], {}),
        ([(b"uid3", 779)], {}),
        # Refused despite https://github.com/samuelcolvin/pydantic/issues/1268
        ("", {}),
    ]
    for value, store_kwargs in rejected_records:
        assert not await bob.store("n_batches", value, expires_soon(), **store_kwargs), (value, store_kwargs)

    for node in (alice, bob):
        entries = (await node.get("n_batches", latest=True)).value
        assert len(entries) == 2
        assert entries[b"uid1"].value == 777
        assert entries[b"uid2"].value == 778


@pytest.mark.forked
@pytest.mark.asyncio
async def test_expecting_public_keys(dht_nodes_with_schema):
    """``signed_data`` subkeys must carry a public key, e.g. ``b"uid[owner:public-key]"``."""
    alice, bob = dht_nodes_with_schema
    keyed_subkey = b"uid[owner:public-key]"

    # Because the subkey carries a public key, hivemind.dht.crypto.RSASignatureValidator would require a signature
    stored_with_key = await bob.store("signed_data", b"foo_bar", get_dht_time() + 10, subkey=keyed_subkey)
    assert stored_with_key

    stored_without_key = await bob.store(
        "signed_data", b"foo_bar", get_dht_time() + 10, subkey=b"uid-without-public-key"
    )
    assert not stored_without_key

    for node in (alice, bob):
        entries = (await node.get("signed_data", latest=True)).value
        assert len(entries) == 1
        assert entries[keyed_subkey].value == b"foo_bar"


@pytest.mark.forked
@pytest.mark.asyncio
async def test_keys_outside_schema(dht_nodes_with_schema):
    """A key outside every merged schema is accepted iff the base validator has ``allow_extra_keys=True``.

    Each iteration starts its own pair of nodes with a freshly merged validator. The nodes provided
    by the ``dht_nodes_with_schema`` fixture are not used here.
    """

    class Schema(BaseModel):
        some_field: StrictInt

    class MergedSchema(BaseModel):
        another_field: StrictInt

    for allow_extra_keys in [False, True]:
        validator = SchemaValidator(Schema, allow_extra_keys=allow_extra_keys)
        # The merged-in validator forbids extra keys, but the result must still follow the base validator's flag
        assert validator.merge_with(SchemaValidator(MergedSchema, allow_extra_keys=False))

        nodes = []
        try:
            alice = await DHTNode.create(record_validator=validator)
            nodes.append(alice)
            bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
            nodes.append(bob)

            store_ok = await bob.store("unknown_key", b"foo_bar", get_dht_time() + 10)
            assert store_ok == allow_extra_keys

            for peer in [alice, bob]:
                result = await peer.get("unknown_key", latest=True)
                if allow_extra_keys:
                    assert result.value == b"foo_bar"
                else:
                    assert result is None
        finally:
            # Release this iteration's nodes before the next iteration (and on failure) so they don't leak
            await asyncio.gather(*(node.shutdown() for node in nodes))


@pytest.mark.forked
@pytest.mark.asyncio
async def test_prefix():
    """With ``prefix="prefix"``, the schema field ``field`` is accepted only under the DHT key ``prefix_field``."""

    class Schema(BaseModel):
        field: StrictInt

    validator = SchemaValidator(Schema, allow_extra_keys=False, prefix="prefix")

    nodes = []
    try:
        alice = await DHTNode.create(record_validator=validator)
        nodes.append(alice)
        bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
        nodes.append(bob)

        assert await bob.store("prefix_field", 777, get_dht_time() + 10)
        assert not await bob.store("prefix_field", "string_value", get_dht_time() + 10)
        # The unprefixed name is not a schema key, so with allow_extra_keys=False it must be refused
        assert not await bob.store("field", 777, get_dht_time() + 10)

        for peer in [alice, bob]:
            assert (await peer.get("prefix_field", latest=True)).value == 777
            assert (await peer.get("field", latest=True)) is None
    finally:
        # Shut down every node that was started, even if an assertion above failed
        await asyncio.gather(*(node.shutdown() for node in nodes))


@pytest.mark.forked
@pytest.mark.asyncio
async def test_merging_schema_validators(dht_nodes_with_schema):
    """Merging extra schemas into a node's ``SchemaValidator`` widens the set of accepted keys and value types."""
    alice, bob = dht_nodes_with_schema

    class TrivialValidator(RecordValidatorBase):
        def validate(self, record: DHTRecord) -> bool:
            return True

    # A SchemaValidator refuses to merge with a validator of a different type
    assert not alice.protocol.record_validator.merge_with(TrivialValidator())

    class SecondSchema(BaseModel):
        some_field: StrictInt
        another_field: str

    class ThirdSchema(BaseModel):
        another_field: StrictInt  # Allow it to be a StrictInt as well

    # Each schema is merged into both nodes before the next schema is tried
    for extra_schema in [SecondSchema, ThirdSchema]:
        extra_validator = SchemaValidator(extra_schema, allow_extra_keys=False)
        for node in [alice, bob]:
            assert node.protocol.record_validator.merge_with(extra_validator)

    # (key, value, whether the store must succeed), in the order they are tried
    expected_store_outcomes = [
        ("experiment_name", b"foo_bar", True),
        ("some_field", 777, True),
        ("some_field", "string_value", False),
        ("another_field", 42, True),
        ("another_field", "string_value", True),
        # Unknown keys are allowed since the first schema is created with allow_extra_keys=True
        ("unknown_key", 999, True),
    ]
    for key, value, should_succeed in expected_store_outcomes:
        # Fresh expiration per store: the second "another_field" record must be newer than the first
        stored = await bob.store(key, value, get_dht_time() + 10)
        assert bool(stored) == should_succeed, (key, value)

    expected_values = {
        "experiment_name": b"foo_bar",
        "some_field": 777,
        "another_field": "string_value",
        "unknown_key": 999,
    }
    for node in [alice, bob]:
        for key, value in expected_values.items():
            assert (await node.get(key, latest=True)).value == value


@pytest.mark.forked
def test_sending_validator_instance_between_processes():
    """Validators registered through ``DHT.add_validators`` are enforced on ``DHT.store`` calls.

    As the test name indicates, the ``SchemaValidator`` instances are sent between processes here,
    so this also checks that they survive that transfer.
    """
    nodes = []
    try:
        alice = DHT(start=True)
        nodes.append(alice)
        bob = DHT(start=True, initial_peers=alice.get_visible_maddrs())
        nodes.append(bob)

        alice.add_validators([SchemaValidator(SampleSchema)])
        bob.add_validators([SchemaValidator(SampleSchema)])

        assert bob.store("experiment_name", b"foo_bar", get_dht_time() + 10)
        assert not bob.store("experiment_name", 777, get_dht_time() + 10)
        assert alice.get("experiment_name", latest=True).value == b"foo_bar"
    finally:
        # Shut down every DHT that was started, even if an assertion above failed
        for node in nodes:
            node.shutdown()
